#!/usr/bin/env python3

import os, sys, json, h5py, numpy as np, torch, torch.nn as nn
from torchvision import transforms
from tqdm import tqdm
from omegaconf import OmegaConf
from accelerate import Accelerator

# ---------- Basic configuration ----------
data_path = "dataset"
cache_dir = "Cache"
subj = 1
save_dir = f"./subj0{subj}_prior_vae_eye"
device_0 = "cuda:0"  # MindEye model (and blurry VAE) live here
device_1 = "cuda:1"  # unCLIP diffusion engine lives here
torch.backends.cuda.matmul.allow_tf32 = True

# ---------- Required parameters ----------
hidden_dim = 4096
n_blocks = 4
blurry_recon = True  # additionally produce a blurry (VAE-decoded) reconstruction
batch_size = 1  # images per inference batch; larger values need more GPU memory
imsize = 256  # output resolution
dtype = torch.float16

# ---------- Read the number of voxels ----------
betas_file = f"{data_path}/betas_all_subj0{subj}_fp32_renorm.hdf5"
with h5py.File(betas_file, "r") as h5:
    # Only the last dimension (voxel count) is needed to size the ridge layer.
    num_voxels = h5["betas"].shape[-1]
print(f"subj{subj} 体素维度: {num_voxels}")

# ---------- Load fMRI data ----------
# Load onto the CPU. Without map_location, torch.load restores tensors to the
# device they were saved from. The inference loop moves each batch to
# device_0 itself, so the full tensor never needs to sit on a GPU.
voxel_path = os.path.join(save_dir, "predicted_fmri_ep150_step100_repeat5.pt")
test_voxel = torch.load(voxel_path, map_location="cpu")
print("fMRI 已加载:", test_voxel.shape)

# ---------- 加载模型 ----------
sys.path.append('./mindeye2_src/')
from mindeye2_src.models import  BrainNetwork, PriorNetwork, BrainDiffusionPrior
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder
import mindeye2_src.utils as utils

class RidgeRegression(nn.Module):
    """Bank of linear maps from voxel space into the shared hidden space.

    One ``nn.Linear(input_size, out_features)`` is created per entry of
    ``input_sizes``. ``forward`` selects the layer by ``subj_idx``. The ridge
    penalty is not part of the module: add ``weight_decay`` to the optimizer
    when training.
    """

    def __init__(self, input_sizes, out_features):
        super().__init__()
        self.out_features = out_features
        # Attribute name ``linears`` must stay as-is: checkpoint keys depend on it.
        layers = [nn.Linear(n_in, out_features) for n_in in input_sizes]
        self.linears = nn.ModuleList(layers)

    def forward(self, x, subj_idx):
        """Project ``x`` with layer ``subj_idx`` and insert a length-1 axis at dim 1."""
        projected = self.linears[subj_idx](x)
        return projected.unsqueeze(1)

class MindEyeModule(nn.Module):
    """Empty container; the script attaches its submodules as attributes after construction."""

    def __init__(self):
        super(MindEyeModule, self).__init__()

# Assemble MindEye as attributes of one container: ridge -> backbone -> diffusion prior.
model = MindEyeModule()
model.ridge = RidgeRegression([num_voxels], out_features=hidden_dim)
model.backbone = BrainNetwork(
    h=hidden_dim, in_dim=hidden_dim, seq_len=1, n_blocks=n_blocks,
    # out_dim = 256 tokens x 1664 dims, matching the prior's num_tokens/dim below
    clip_size=1664, out_dim=1664 * 256, blurry_recon=blurry_recon, clip_scale=1.0,
)

prior_net = PriorNetwork(
    dim=1664, depth=6, dim_head=52, heads=32,
    causal=False, num_tokens=256, learned_query_mode="pos_emb",
)
model.diffusion_prior = BrainDiffusionPrior(
    net=prior_net, image_embed_dim=1664,
    condition_on_text_encodings=False,
    timesteps=100, cond_drop_prob=0.2,
)

# Resolve the checkpoint under cache_dir, like the other checkpoint loads,
# instead of a second hard-coded "Cache" root.
ckpt_path = f"{cache_dir}/final_subj0{subj}_pretrained_40sess_24bs/last.pth"
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
# load_state_dict copied the weights into the model. Drop the reference so
# the loaded checkpoint dict can be garbage-collected.
del ckpt
model.to(device_0).eval()
print("MindEye 模型已加载")

# ---------- VAE for blurry reconstructions ----------
# Built only when blurry_recon is enabled. The inference loop uses it to decode
# the backbone's blurry latent.
if blurry_recon:
    from diffusers import AutoencoderKL

    autoenc = AutoencoderKL(
        down_block_types=["DownEncoderBlock2D"] * 4,
        up_block_types=["UpDecoderBlock2D"] * 4,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=256,
    )
    vae_state = torch.load(f"{cache_dir}/sd_image_var_autoenc.pth")
    autoenc.load_state_dict(vae_state)
    del vae_state  # weights are now copied into autoenc
    # Inference only: eval mode, frozen parameters, placed on the MindEye GPU.
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to(device_0)

# ---------- unCLIP ----------
sys.path.append('./mindeye2_src/generative_models')
config = OmegaConf.load("./mindeye2_src/generative_models/configs/unclip6.yaml")
cfg = OmegaConf.to_container(config, resolve=True)["model"]["params"]
from generative_models.sgm.models.diffusion import DiffusionEngine

diffusion_engine = DiffusionEngine(
    network_config=cfg["network_config"],
    denoiser_config=cfg["denoiser_config"],
    first_stage_config=cfg["first_stage_config"],
    conditioner_config=cfg["conditioner_config"],
    sampler_config=cfg["sampler_config"],
    scale_factor=cfg["scale_factor"],
    disable_first_stage_autocast=cfg["disable_first_stage_autocast"],
)
# Load the checkpoint onto the CPU. Without map_location, torch.load restores
# tensors to the device they were saved from, which could place a large
# temporary copy on a GPU that may already hold the MindEye model.
unclip_ckpt = torch.load(f"{cache_dir}/unclip6_epoch0_step110000.ckpt", map_location="cpu")
diffusion_engine.load_state_dict(unclip_ckpt["state_dict"])
del unclip_ckpt  # weights are copied into the engine; release the dict
diffusion_engine.eval().requires_grad_(False).to(device_1)
diffusion_engine.sampler.device = device_1  # "cuda:1"

# Placeholder batch fed through the conditioner only to obtain its "vector"
# output, which is reused for every reconstruction.
# NOTE(review): presumably this encodes the size/crop conditioning (768, top-left 0) — confirm.
batch = {"jpg": torch.randn(1, 3, 1, 1).to(device_1),
         "original_size_as_tuple": torch.ones(1, 2).to(device_1) * 768,
         "crop_coords_top_left": torch.zeros(1, 2).to(device_1)}
vector_suffix = diffusion_engine.conditioner(batch)["vector"].to(device_1)

# ---------- Inference ----------
model.to(device_0).eval()
all_recons, all_blurryrecons = None, None
n_imgs = len(test_voxel)

# Number of ridge+backbone passes averaged per batch. Every pass receives the
# same voxel tensor, so with a deterministic backbone (eval mode) the passes
# are identical.
# NOTE(review): kept at 3 to preserve the original averaging; confirm it is still needed.
n_backbone_reps = 3

# Collect per-batch outputs and concatenate once at the end, instead of
# re-concatenating the growing result every iteration (quadratic copying).
recon_chunks, blurry_chunks = [], []

with torch.no_grad(), torch.autocast("cuda", dtype=dtype):
    for start in tqdm(range(0, n_imgs, batch_size)):
        end = min(start + batch_size, n_imgs)
        voxel = test_voxel[start:end].to(device_0)  # assumes (B, num_voxels) — TODO confirm

        # ridge + backbone, averaged over n_backbone_reps passes.
        # The CLIP-space output of the backbone is not used by this script.
        backbone, blurry_enc = None, None
        for _ in range(n_backbone_reps):
            v = model.ridge(voxel, 0)
            b, _clip, blur = model.backbone(v)
            backbone = b if backbone is None else backbone + b
            blurry_enc = blur[0] if blurry_enc is None else blurry_enc + blur[0]
        backbone = backbone / n_backbone_reps
        blurry_enc = blurry_enc / n_backbone_reps

        # diffusion prior -> unCLIP conditioning; shape matches backbone
        prior_out = model.diffusion_prior.p_sample_loop(
            backbone.shape, text_cond={"text_embed": backbone},
            cond_scale=1.0, timesteps=20,
        ).to(device_1)

        # Reconstruct one sample at a time (prior_out[[i]] keeps a leading batch
        # dim of 1). unclip_recon is driven with num_samples=1 per call, so
        # passing the whole batch would only work for batch_size == 1.
        for i in range(len(prior_out)):
            samples = utils.unclip_recon(prior_out[[i]], diffusion_engine, vector_suffix,
                                         num_samples=1)
            recon_chunks.append(samples.cpu())

        # Blurry reconstruction: decode the averaged latent with the VAE, map
        # [-1, 1] -> [0, 1], and resize to imsize.
        if blurry_recon:
            blur_img = (autoenc.decode(blurry_enc / 0.18215).sample / 2 + 0.5).clamp(0, 1)
            blur_img = transforms.Resize((imsize, imsize))(blur_img).cpu()
            blurry_chunks.append(blur_img)

# Leave the results as None when there was nothing to process.
if recon_chunks:
    all_recons = torch.cat(recon_chunks)
if blurry_chunks:
    all_blurryrecons = torch.cat(blurry_chunks)

# ---------- Save ----------
# Ensure the output directory exists before writing anything, and write each file once.
os.makedirs(save_dir, exist_ok=True)

# unCLIP reconstructions
torch.save(all_recons, os.path.join(save_dir, "all_recons.pt"))

# Blurry reconstructions (only when enabled)
if blurry_recon:
    torch.save(all_blurryrecons, os.path.join(save_dir, "all_blurryrecons.pt"))
print("✅ 重建完成，已保存:", save_dir)